import torch.nn as nn


def rnn_factory(rnn_type, **kwargs):
    """Build a ``torch.nn`` module selected by name.

    Args:
        rnn_type: Name of an attribute of ``torch.nn`` to instantiate.
            Presumably a recurrent layer class such as ``"LSTM"`` or
            ``"GRU"``; no check is made that it actually is one.
        **kwargs: Keyword arguments passed unchanged to the constructor.

    Returns:
        Whatever calling the looked-up attribute with ``**kwargs`` returns.

    Raises:
        AttributeError: If ``torch.nn`` has no attribute named ``rnn_type``.
    """
    # Resolve the constructor first, then call it, so a bad name fails
    # before any construction is attempted.
    module_cls = getattr(nn, rnn_type)
    return module_cls(**kwargs)
